{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "handed-robinson",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "\n",
    "This document concerns how to use the GluonTS implementation of [Level Set Forecaster](https://papers.nips.cc/paper/2021/hash/32b127307a606effdcc8e51f60a45922-Abstract.html). Level Set Forecaster (LSF) is an algorithm for a tabular point prediction algorithm into a probabilistic one.\n",
    "\n",
    "A tabular point prediction algorithm is an algorithm that takes in tabular data and, at inference, outputs real numbers. Examples include: linear regression, random forests, XGBoost, etc. A tabular probabilistic prediction algorithm is one that at inference outputs an estimate of the conditional distribution.\n",
    "\n",
    "The high-level description of the LSF algorithm is that it bins the training data in such a manner that the predictions of the feature vectors in each bin are similar. At inference, it then uses the empirical distribution of true values for the bin associated with the new feature vector. We invite those who are interested to read the details and the theoretical guarantees in the [NeurIPS paper](https://papers.nips.cc/paper/2021/hash/32b127307a606effdcc8e51f60a45922-Abstract.html). \n",
    "\n",
    "The better your underlying model, the better performing its LSF wrapping will be."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "numerous-marks",
   "metadata": {},
   "source": [
    "# Retrofitting an Existing Model\n",
    "\n",
    "Assume that you have trained a tabular point prediction algorithm and named it `underlying_model`. In order to wrap it with LSF, you have to feed into LSF the a dataset from the same distribution that `underlying_model` was trained on:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "destroyed-champion",
   "metadata": {},
   "source": [
    "```\n",
    "from gluonts.model.rotbaum._model import LSF\n",
    "model = LSF(model=underlying_model, min_bin_size=100)\n",
    "model.fit(X_train, y_train, model_is_already_trained=True)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aboriginal-arkansas",
   "metadata": {},
   "source": [
    "It is possible to provide LSF the exact same dataset that `underlying_model` was trained on, but it is not necessary. There is also no assumption that the dataset provided to LSF is the same size as the one provided to `underlying_model`.\n",
    "\n",
    "The hyperparameter `min_bin_size` is by default `100`. While the NeurIPS paper ensures consistency, under certain conditions, with an increasing `min_bin_size` (of order of magnitude `(ln(n))^2`), in practice we have found that keeping this hyperparameter at `100` works in a surprisingly wide variety of datasets.\n",
    "\n",
    "Note that some `underlying_model`s assume that `X_train` is a list of lists (or numpy array of numpy arrays), whereas some assume that it is a pandas dataframe. By default LSF works with the former. To apply LSF to the latter, simply set `x_train_is_dataframe=True` in `model.fit`.\n",
    "\n",
    "Note that while LSF is native to the tabular use case, it is possible to apply it with neural networks as well. To that end, LSF can come in at the embedding level of the architecture. To be a little more explicit, if f1 is the portion of the neural network that embeds the data into a fixed dimensional space, and the remainder of the network is f2, then after training you would cache pairs of embeddings and true target values, and later feed them into LSF together with f2 as the base algorithm. At inference, rather than feeding in the raw data, you would feed the data into f1 to obtain a fixed length feature vector, and only then feed it into the LSF."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "altered-hepatitis",
   "metadata": {},
   "source": [
    "# Letting LSF Train the Tabular Model As Well\n",
    "\n",
    "Letting `model_is_already_trained=False` (as is the default), LSF will train `underlying_model` first and only then wrap it with LSF.\n",
    "\n",
    "For example:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dimensional-orange",
   "metadata": {},
   "source": [
    "```\n",
    "from gluonts.model.rotbaum._model import LSF\n",
    "from xgboost import XGBModel\n",
    "underlying_model = XGBModel(\n",
    "         base_score=0.5, booster='gbtree', colsample_bylevel=1,\n",
    "         colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1,\n",
    "         importance_type='gain', interaction_constraints='',\n",
    "         learning_rate=0.300000012, max_delta_step=0, max_depth=3,\n",
    "         min_child_weight=1, monotone_constraints='()',\n",
    "         n_estimators=50, n_jobs=-1, num_parallel_tree=1,\n",
    "         objective='reg:squarederror', random_state=0, reg_alpha=0,\n",
    "         reg_lambda=1, scale_pos_weight=1, subsample=1, tree_method='exact',\n",
    "         validate_parameters=1, verbosity=1)\n",
    "model = LSF(model=underlying_model)\n",
    "model.fit(X_train, y_train)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "hearing-vermont",
   "metadata": {},
   "source": [
    "In case you want LSF to use less data to create the bins than the `underyling_model` uses to train, simply set `max_sample_size`. LSF will then sample `min(max_sample_size, len(X_train)` many data points from the training data without replacement for the purpose of creating the bins."
   ]
  },
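  {
   "cell_type": "markdown",
   "id": "sampled-binning-sketch",
   "metadata": {},
   "source": [
    "The sampling described above can be illustrated with plain NumPy. This is only a sketch of the stated semantics (drawing `min(max_sample_size, len(X_train))` points without replacement), not the GluonTS code; the helper name `subsample_for_bins` and its `seed` argument are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def subsample_for_bins(X_train, y_train, max_sample_size, seed=0):\n",
    "    # Draw min(max_sample_size, len(X_train)) distinct row indices.\n",
    "    n_samples = min(max_sample_size, len(X_train))\n",
    "    rng = np.random.default_rng(seed)\n",
    "    idx = rng.choice(len(X_train), size=n_samples, replace=False)\n",
    "    # Keep features and targets aligned by indexing both with the same rows.\n",
    "    return [X_train[i] for i in idx], [y_train[i] for i in idx]\n",
    "```"
   ]
  },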
  {
   "cell_type": "markdown",
   "id": "funded-plymouth",
   "metadata": {},
   "source": [
    "# LSF Wrapping XGBoost\n",
    "\n",
    "By default, without specifying an `underlying_model`, LSF will wrap XGBoost with some default parameters:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "nominated-skirt",
   "metadata": {},
   "source": [
    "```\n",
    "from gluonts.model.rotbaum._model import LSF\n",
    "model.fit(X_train, y_train)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "capable-boating",
   "metadata": {},
   "source": [
    "# Inference\n",
    "\n",
    "If we want to estimate the conditional quantile for a specific quantile (between 0 and 1), simply query it thusly:\n",
    "```\n",
    "model.predict(X_test, quantile)\n",
    "```\n",
    "\n",
    "One can also retrieve the bin in its entirety, which in turn can be interpreted as an estimated sampling from the conditional distribution. To be precise `model.estimate_dist(X_test)` will output, in pseudocode, `[list of true values whose associated feature vectors are in the same bin as x for x in X_test]`.\n",
    "\n",
    "One can then plot histograms of the estimated conditional distributions:\n",
    "```\n",
    "from matplotlib import pyplot as plt\n",
    "plt.hist(model.estimate_dist(X_test)[0])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "whole-start",
   "metadata": {},
   "source": [
    "# Quick Synthetic Example to Get You Started"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "casual-fever",
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from gluonts.model.rotbaum._model import LSF\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "great-retailer",
   "metadata": {},
   "outputs": [],
   "source": [
    "def weighted_quantile_loss(true, pred, quantile):\n",
    "    denom = sum(np.abs(true))\n",
    "    num = sum(\n",
    "        [\n",
    "            (1 - quantile) * abs(y_hat - y)\n",
    "            if y_hat > y\n",
    "            else quantile * abs(y_hat - y)\n",
    "            for y_hat, y in zip(pred, true)\n",
    "        ]\n",
    "    )\n",
    "    if denom != 0:\n",
    "        return 2 * num / denom\n",
    "    else:\n",
    "        return None"
   ]
  },
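  {
   "cell_type": "markdown",
   "id": "weighted-quantile-loss-formula",
   "metadata": {},
   "source": [
    "For reference, the function `weighted_quantile_loss` above evaluates the quantile (pinball) loss normalized by the total absolute target. Written out from the code, with $y_i$ the true values, $\\hat{y}_i$ the predictions and $q$ the quantile level (the symbol $\\mathrm{wQL}_q$ is a label chosen here for illustration):\n",
    "\n",
    "$$\n",
    "\\mathrm{wQL}_q = \\frac{2 \\sum_i \\rho_q(y_i, \\hat{y}_i)}{\\sum_i |y_i|}, \\qquad \\rho_q(y, \\hat{y}) = \\begin{cases} (1 - q)\\,(\\hat{y} - y) & \\text{if } \\hat{y} > y \\\\ q\\,(y - \\hat{y}) & \\text{otherwise} \\end{cases}\n",
    "$$\n",
    "\n",
    "The function returns `None` when the denominator is zero."
   ]
  },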
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "changing-creation",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = [\n",
    "    [np.random.normal(0, 1), np.random.normal(0, 1)] for i in range(10000)\n",
    "]\n",
    "y_train = [i + j + np.random.normal(0, abs(i + j)) for i, j in X_train]\n",
    "X_test = [\n",
    "    [np.random.normal(0, 1), np.random.normal(0, 1)] for i in range(10000)\n",
    "]\n",
    "y_test = [i + j + np.random.normal(0, abs(i + j)) for i, j in X_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "protective-freight",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LinearRegression()"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "underlying_model = LinearRegression()\n",
    "underlying_model.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "quarterly-nowhere",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LSF(model=underlying_model, min_bin_size=100)\n",
    "model.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "catholic-milwaukee",
   "metadata": {},
   "outputs": [],
   "source": [
    "P10 = model.predict(X_test, 0.1)\n",
    "P50 = model.predict(X_test, 0.5)\n",
    "P90 = model.predict(X_test, 0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "amazing-florida",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.30368475167009445\n",
      "0.6872476043155556\n",
      "0.3032539971588588\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(array([ 4.,  4., 14., 19., 28., 14.,  9.,  6.,  1.,  1.]),\n",
       " array([-0.99932272, -0.48932945,  0.02066383,  0.53065711,  1.04065038,\n",
       "         1.55064366,  2.06063693,  2.57063021,  3.08062348,  3.59061676,\n",
       "         4.10061004]),\n",
       " <BarContainer object of 10 artists>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAALZklEQVR4nO3dX4hc9RnG8eepsViqpZEMIWjoFhEhFIxlSS0pxfqnRFOqFloaqORCiBcKEYSy9ab2LoWqvSlCrEGh1iKoKI3UpjYggthubKrRVBSJNCFmV6SoNy0xTy/2pC7bXWey8+f47nw/sOzMmbP7e49mvxzOzsw6iQAA9Xym7QEAAMtDwAGgKAIOAEURcAAoioADQFGrRrnYmjVrMjExMcolAaC8AwcOvJuks3D7SAM+MTGh6enpUS4JAOXZfnux7VxCAYCiCDgAFEXAAaAoAg4ARRFwACiKgANAUQQcAIoi4ABQFAEHgKJG+kpMoJuJqb2trHtk19ZW1gX6wRk4ABRFwAGgKAIOAEURcAAoioADQFEEHACKIuAAUBQBB4CiCDgAFEXAAaAoAg4ARRFwACiKgANAUQQcAIoi4ABQFAEHgKIIOAAURcABoKiuAbe93vZ+26/ZftX2zmb7XbaP2T7YfFw3/HEBAKf18jcxT0q6I8lLts+TdMD2vuaxe5P8YnjjAQCW0jXgSY5LOt7c/sD2YUkXDHswAMAnO6Nr4LYnJF0m6cVm0222X7a9x/bqJb5mh+1p29Ozs7P9TQsA+J+eA277XEmPSbo9yfuS7pN0kaSNmjtDv3uxr0uyO8lkkslOp9P/xAAAST0G3PbZmov3w0kel6QkJ5J8lOSUpPslbRremACAhXp5FoolPSDpcJJ75m1fN2+3GyUdGvx4AICl9PIslM2SbpL0iu2DzbY7JW2zvVFSJB2RdMsQ5gMALKGXZ6E8L8mLPPT04McBAPSKV2ICQFEEHACKIuAAUBQBB4CiCDgAFEXAAaAoAg4ARRFwACiKgANAUQQcAIoi4ABQFAEHgKIIOAAURcABoCgCDgBFEXAAKIqAA0BRBBwAiiLgAFAUAQeAogg4ABRFwAGgKAIOAEURcAAoioADQFEEHACKIuAAUBQBB4Ciugbc9nrb+22/ZvtV2zub7efb3mf7jebz6uGPCwA4rZcz8JOS7kiyQdLlkm61vUHSlKRnk1ws6dnmPgBgRLoGPMnxJC81tz+QdFjSBZKul/RQs9tDkm4Y0owAgEWc0TVw2xOSLpP0oqS1SY43D70jae0SX7PD9rTt6dnZ2X5mBQDM03PAbZ8r6TFJtyd5f/5jSSIpi31dkt1JJpNMdjqdvoYFAHysp4DbPltz8X44yePN5hO21zWPr5M0M5wRAQCL6eVZKJb0gKTDSe6Z99BTkrY3t7dLenLw4wEAlrKqh302S7pJ0iu2Dzbb7pS0S9Kjtm+W9LakHwxlQgDAoroGPMnzkrzEw1cNdhwAQK94JSYAFNXLJRSMmYmpvW2PAKAHnIEDQFEEHACKIuAAUBQBB4CiCDgAFEXAAaAoAg4ARRFwACiKgANAUQQcAIoi4ABQFAEHgKIIOAAURcABoCgCDgBFEXAAKIqAA0BRBBwAiiLgAFAUAQeAogg4ABRFwAGgKAIOAEURcAAoioADQFEEHACK6hpw23tsz9g+NG/bXbaP2T7YfFw33DEBAAv1cgb+oKQti2y/N8nG5uPpwY4FAOima8CTPCfpvRHMAgA4A/1cA7/N9svNJZbVS+1ke4ftadvTs7OzfSwHAJhvuQG/T9JFkjZKOi7p7qV2TLI7yWSSyU6ns8zlAAALLSvgSU4k+SjJKUn3S9o02LEAAN0sK+C21827e6OkQ0vtCwAYjlXddrD9iKQrJK2xfVTSTyVdYXujpEg6IumW4Y0IAFhM14An2bbI5geGMAsA4AzwSkwAKKrrGTjaMzG1t+0Rxkab/62P7Nra2tqojTNwACiK
gANAUQQcAIoi4ABQFAEHgKIIOAAURcABoCgCDgBFEXAAKIqAA0BRBBwAiiLgAFAUAQeAogg4ABRFwAGgKAIOAEURcAAoioADQFEEHACKIuAAUBQBB4CiCDgAFEXAAaAoAg4ARRFwACiKgANAUV0DbnuP7Rnbh+ZtO9/2PttvNJ9XD3dMAMBCvZyBPyhpy4JtU5KeTXKxpGeb+wCAEeoa8CTPSXpvwebrJT3U3H5I0g2DHQsA0M1yr4GvTXK8uf2OpLVL7Wh7h+1p29Ozs7PLXA4AsFDfv8RMEkn5hMd3J5lMMtnpdPpdDgDQWG7AT9heJ0nN55nBjQQA6MVyA/6UpO3N7e2SnhzMOACAXvXyNMJHJL0g6RLbR23fLGmXpGtsvyHp6uY+AGCEVnXbIcm2JR66asCzAADOAK/EBICiup6BAxiuiam9rax7ZNfWVtbF4HAGDgBFEXAAKIqAA0BRBBwAiiLgAFAUAQeAogg4ABRFwAGgKAIOAEURcAAoioADQFEEHACKIuAAUBQBB4CiCDgAFEXAAaAoAg4ARRFwACiKgANAUQQcAIoi4ABQFH+VHhhTE1N7W1v7yK6tra29knAGDgBFEXAAKIqAA0BRBBwAiurrl5i2j0j6QNJHkk4mmRzEUACA7gbxLJRvJXl3AN8HAHAGuIQCAEX1G/BI+qPtA7Z3LLaD7R22p21Pz87O9rkcAOC0fgP+jSRflXStpFttf3PhDkl2J5lMMtnpdPpcDgBwWl8BT3Ks+Twj6QlJmwYxFACgu2UH3PbnbZ93+rakb0s6NKjBAACfrJ9noayV9ITt09/nt0n+MJCpAABdLTvgSd6SdOkAZwEAnAGeRggARZV5O9k23/oSAD6NOAMHgKIIOAAURcABoCgCDgBFEXAAKIqAA0BRBBwAiiLgAFAUAQeAogg4ABRFwAGgKAIOAEURcAAoioADQFEEHACKIuAAUBQBB4CiCDgAFEXAAaAoAg4ARRFwACiqzF+lB7ByTEztbXuEkTuya+vAvydn4ABQFAEHgKIIOAAURcABoKi+Am57i+3Xbb9pe2pQQwEAult2wG2fJelXkq6VtEHSNtsbBjUYAOCT9XMGvknSm0neSvIfSb+TdP1gxgIAdNPP88AvkPTPefePSvrawp1s75C0o7n7oe3Xl7neGknvLvNrKxqn4x2nY5U43pVu0eP1z/v6nl9abOPQX8iTZLek3f1+H9vTSSYHMFIJ43S843SsEse70o3yePu5hHJM0vp59y9stgEARqCfgP9V0sW2v2z7s5J+KOmpwYwFAOhm2ZdQkpy0fZukZySdJWlPklcHNtn/6/syTDHjdLzjdKwSx7vSjex4nWRUawEABohXYgJAUQQcAIoqFXDb37f9qu1Ttlfk05LG6e0JbO+xPWP7UNuzjILt9bb3236t+Xe8s+2ZhsX2Obb/YvvvzbH+rO2ZRsH2Wbb/Zvv3o1ivVMAlHZL0PUnPtT3IMIzh2xM8KGlL20OM0ElJdyTZIOlySbeu4P+//5Z0ZZJLJW2UtMX25e2ONBI7JR0e1WKlAp7kcJLlvpKzgrF6e4Ikz0l6r+05RiXJ8SQvNbc/0NwP+gXtTjUcmfNhc/fs5mNFP2PC9oWStkr69ajWLBXwMbDY2xOsyB/wcWd7QtJlkl5seZShaS4nHJQ0I2lfkhV7rI1fSvqxpFOjWvBTF3Dbf7J9aJGPFXsmivFi+1xJj0m6Pcn7bc8zLEk+SrJRc6/S3mT7Ky2PNDS2vyNpJsmBUa77qfujxkmubnuGFvH2BCuc7bM1F++Hkzze9jyjkORftvdr7vcdK/UX1pslfdf2dZLOkfQF279J8qNhLvqpOwMfc7w9wQpm25IekHQ4yT1tzzNMtju2v9jc/pykayT9o9WhhijJT5JcmGRCcz+3fx52vKViAbd9o+2jkr4uaa/tZ9qeaZCSnJR0+u0JDkt6dMhvT9Aq249IekHSJbaP2r657ZmGbLOkmyRdaftg83Fd
20MNyTpJ+22/rLkTk31JRvLUunHCS+kBoKhSZ+AAgI8RcAAoioADQFEEHACKIuAAUBQBB4CiCDgAFPVfMtl9qX81CjwAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(weighted_quantile_loss(y_test, P10, 0.1))\n",
    "print(weighted_quantile_loss(y_test, P50, 0.5))\n",
    "print(weighted_quantile_loss(y_test, P90, 0.9))\n",
    "plt.hist(model.estimate_dist([[0.5, 0.5]])[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "standing-trouble",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[2.288148404849137,\n",
       " 1.574696116254158,\n",
       " 0.6568102103240447,\n",
       " 2.2993520742685227,\n",
       " 0.39109241176636667,\n",
       " 1.7025821798202085,\n",
       " 0.6855239431277019,\n",
       " 1.0452498711450835,\n",
       " 1.512439052988055,\n",
       " 0.8593260692605851,\n",
       " -0.6962483603681517,\n",
       " 0.139817808528079,\n",
       " 1.64273291259174,\n",
       " 1.1756085357251278,\n",
       " 1.2286820644502852,\n",
       " 2.327047163989284,\n",
       " 0.44771284640985787,\n",
       " 1.467409266835301,\n",
       " 0.9302439674991858,\n",
       " 1.03700532459507,\n",
       " 0.008129538932130087,\n",
       " 1.1368726098692654,\n",
       " 1.6079078731856582,\n",
       " -0.2317150940367927,\n",
       " 0.675471791654082,\n",
       " 2.0313939483686605,\n",
       " 0.8312038376464166,\n",
       " 1.645705787029204,\n",
       " 0.48299689048994443,\n",
       " 1.579827037650825,\n",
       " 1.165238373622516,\n",
       " 1.190270554221262,\n",
       " 1.4542366173260606,\n",
       " 1.2644800116574475,\n",
       " 1.4307864289204961,\n",
       " 0.7068177077124251,\n",
       " 0.13010745707289473,\n",
       " 1.2251120802386763,\n",
       " 0.3693318284416732,\n",
       " 1.1352689844230193,\n",
       " 1.9552669587197022,\n",
       " 2.4357911384654534,\n",
       " -0.0026836468661575186,\n",
       " 1.2065815093215613,\n",
       " 2.0651693948249856,\n",
       " 0.6721360437208823,\n",
       " 1.0491761310070071,\n",
       " 3.2172861439696208,\n",
       " 1.4453451621110442,\n",
       " 2.97958136611339,\n",
       " 2.1761274430887587,\n",
       " 0.3963338828779327,\n",
       " 2.267448034530609,\n",
       " 1.3228653106456072,\n",
       " 1.1575712211495195,\n",
       " 0.26233424616648937,\n",
       " 1.8000744038081407,\n",
       " 2.8989967920415873,\n",
       " 1.454721393152594,\n",
       " 1.031930958227963,\n",
       " 0.8561592385698967,\n",
       " 1.5360408249685844,\n",
       " 1.3960998629765138,\n",
       " 1.0472666019682473,\n",
       " 1.354787671331422,\n",
       " 2.71855795548974,\n",
       " 1.1237632317683988,\n",
       " 4.100610036246034,\n",
       " 0.15966831658768088,\n",
       " -0.6382599558124438,\n",
       " 0.36119013986411597,\n",
       " 1.5507273097237748,\n",
       " 0.7619370183841324,\n",
       " 1.5985185902942805,\n",
       " 0.1754941535388811,\n",
       " 0.4230780642143368,\n",
       " 0.6177496086652821,\n",
       " 1.6052831235499059,\n",
       " 2.65428516371382,\n",
       " 2.4405959721339556,\n",
       " 1.3484173092663088,\n",
       " 0.3135162527573676,\n",
       " 1.4690866663891833,\n",
       " 1.029260031206032,\n",
       " 0.46131889403271176,\n",
       " 2.3861704901341687,\n",
       " -0.9993227221284857,\n",
       " 1.659012249613435,\n",
       " 1.9533371600879148,\n",
       " 0.646100948352357,\n",
       " 0.876088514976303,\n",
       " 2.830688738531131,\n",
       " 0.8617456612312852,\n",
       " -0.7929279473938164,\n",
       " 0.5993728411844842,\n",
       " 1.3607497734619327,\n",
       " 1.3667941564055293,\n",
       " 0.6506879219033462,\n",
       " 2.6083040470905017,\n",
       " -0.29895103619716923]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.estimate_dist([[0.5, 0.5]])[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "typical-adrian",
   "metadata": {},
   "source": [
    "The bins in their entirety are stored in `model.id_to_bins`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "classified-fiction",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([  0.,   0.,   0.,   0.,   0., 100.,   0.,   0.,   0.,   0.]),\n",
       " array([ 99.5,  99.6,  99.7,  99.8,  99.9, 100. , 100.1, 100.2, 100.3,\n",
       "        100.4, 100.5]),\n",
       " <BarContainer object of 10 artists>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOQUlEQVR4nO3df4xlZ13H8feHjiu0CLttx02727qNLWhDAi2TuhWtyBLCD3UbRVKidsWS/UPUQk1kNSaNgT9agwKNpmZDwcVgpaloqyBS19ZKYldnaW233UK3hdJdd7tDoKDwR6l+/eOewmWYYXfuuXdm9tn3K7m55zznPOd8n713P3PmuT8mVYUkqS3PWekCJEnjZ7hLUoMMd0lqkOEuSQ0y3CWpQVMrXQDAmWeeWZs2bVrpMiTphLJ3794vVdX0QttWRbhv2rSJ2dnZlS5Dkk4oSR5fbJvTMpLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBxwz3JB9McjTJvqG205PckeSR7n5d154kNyQ5kOT+JBdPsnhJ0sKO58r9z4HXzmvbAeyuqguA3d06wOuAC7rbduDG8ZQpSVqKY4Z7Vd0NfHle81ZgV7e8C7h8qP3DNXAPsDbJWWOqVZJ0nEb9hOr6qjrcLR8B1nfLG4AnhvY72LUdZp4k2xlc3XPuueeOWIY0WZt2fHzFzv2F696wYufWia/3C6o1+FNOS/5zTlW1s6pmqmpmenrBr0aQJI1o1HB/8tnplu7+aNd+CDhnaL+NXZskaRmNGu63A9u65W3AbUPtV3bvmtkMfHVo+kaStEyOOeee5GbglcCZSQ4C1wLXAbckuQp4HHhTt/sngNcDB4BvAG+ZQM2SpGM4ZrhX1ZsX2bRlgX0LeFvfoiRJ/fgJVUlqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDeoV7knckeTDJviQ3J3lukvOS7ElyIMlHk6wZV7GSpOMzcrgn2QD8FjBTVS8BTgGuAK4H3ltV5wNfAa4aR6GSpOPXd1pmCnhekingVOAw8Crg1m77LuDynueQJC3RyOFeVYeA9wBfZBDqXwX2Ak9V1TPdbgeBDQv1T7I9yWyS2bm5uVHLkCQtoM+0zDpgK3AecDZwGvDa4+1fVTuraqaqZqanp0ctQ5K0gD7TMq8GPl9Vc1X1TeBjwCuAtd00DcBG4FDPGiVJS9Qn3L8IbE5yapIAW4CHgDuBN3b7bANu61eiJGmp+sy572HwwulngAe6Y+0E3glck+QAcAZw0xjqlCQtwdSxd1lcVV0LXDuv+THgkj7HlST14ydUJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ3qFe5J1ia5NcnDSfYnuTTJ6UnuSPJId79uXMVKko5P3yv39wOfrKofAV4K7Ad2ALur6gJgd7cuSVpGI4d7khcClwE3AVTV01X1FLAV2NXttgu4vF+JkqSl6nPlfh4wB3woyb1JPpDkNGB9VR3u9jkCrO9bpCRpafqE+xRwMXBjVV0EfJ15UzBVVUAt1DnJ9iSzSWbn5uZ6lCFJmq9PuB8EDlbVnm79VgZh/2SSswC6+6MLda6qnVU1U1Uz09PTPcqQJM03crhX1RHgiSQv7pq2AA8BtwPburZtwG29KpQkLdlUz/6/CXwkyRrgMeAt
DH5g3JLkKuBx4E09zyFJWqJe4V5V9wEzC2za0ue4kqR+/ISqJDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1qHe4Jzklyb1J/r5bPy/JniQHknw0yZr+ZUqSlmIcV+5XA/uH1q8H3ltV5wNfAa4awzkkSUvQK9yTbATeAHygWw/wKuDWbpddwOV9ziFJWrq+V+7vA34H+L9u/Qzgqap6pls/CGxYqGOS7Ulmk8zOzc31LEOSNGzkcE/yM8DRqto7Sv+q2llVM1U1Mz09PWoZkqQFTPXo+wrg55K8Hngu8ALg/cDaJFPd1ftG4FD/MiVJSzHylXtV/W5VbayqTcAVwD9X1S8BdwJv7HbbBtzWu0pJ0pJM4n3u7wSuSXKAwRz8TRM4hyTpe+gzLfMtVXUXcFe3/BhwyTiOK0kajZ9QlaQGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDVo5HBPck6SO5M8lOTBJFd37acnuSPJI939uvGVK0k6Hn2u3J8BfruqLgQ2A29LciGwA9hdVRcAu7t1SdIyGjncq+pwVX2mW/5vYD+wAdgK7Op22wVc3rNGSdISjWXOPckm4CJgD7C+qg53m44A6xfpsz3JbJLZubm5cZQhSer0Dvckzwf+Gnh7VX1teFtVFVAL9auqnVU1U1Uz09PTfcuQJA3pFe5Jvo9BsH+kqj7WNT+Z5Kxu+1nA0X4lSpKWqs+7ZQLcBOyvqj8e2nQ7sK1b3gbcNnp5kqRRTPXo+wrgV4AHktzXtf0ecB1wS5KrgMeBN/WqUJK0ZCOHe1V9Gsgim7eMelxJUn9+QlWSGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkho0kXBP8tokn01yIMmOSZxDkrS4sYd7klOAPwVeB1wIvDnJheM+jyRpcZO4cr8EOFBVj1XV08BfAVsncB5J0iKmJnDMDcATQ+sHgR+bv1OS7cD2bvV/knx2ArVM2pnAl1a6iGV2so15xcab61firMDJ9xjDiTvmH1pswyTC/bhU1U5g50qdfxySzFbVzErXsZxOtjGfbOMFx9yKSUzLHALOGVrf2LVJkpbJJML9P4ALkpyXZA1wBXD7BM4jSVrE2KdlquqZJL8B/CNwCvDBqnpw3OdZJU7oaaURnWxjPtnGC465Camqla5BkjRmfkJVkhpkuEtSgwz3BSS5Osm+JA8meXvX9tIk/5bkgSR/l+QFi/Rdm+TWJA8n2Z/k0mUtfkQ9x/yOrt++JDcnee6yFn+cknwwydEk+4baTk9yR5JHuvt1XXuS3NB9hcb9SS5e5Jgv7/59DnT7Z7nGczzGPeYkpyb5ePf8fjDJdcs5nuMxicd56Di3Dx93Vasqb0M34CXAPuBUBi84/xNwPoN3Af1Ut8+vAe9apP8u4K3d8hpg7UqPaZJjZvChtc8Dz+vWbwF+daXHtMg4LwMuBvYNtf0hsKNb3gFc3y2/HvgHIMBmYM8ix/z3bnu6/V+30uOc5Ji758hPd8trgH9tfcxDx/h5
4C+Hj7uab165f7cfZfAAf6OqngH+hcGD+iLg7m6fO4BfmN8xyQsZPLFuAqiqp6vqqeUouqeRx9yZAp6XZIrBf/7/mnC9I6mqu4Evz2veyuAHMt395UPtH66Be4C1Sc4a7titv6Cq7qnB//4PD/VfFcY95u45cme3/DTwGQafZVk1xj1mgCTPB64B3j2RoifAcP9u+4CfTHJGklMZ/GQ/B3iQb39Hzi/ynR/UetZ5wBzwoST3JvlAktOWo+ieRh5zVR0C3gN8ETgMfLWqPrUsVY/H+qo63C0fAdZ3ywt9jcaGeX03dO3fa5/VqM+YvyXJWuBngd0TqHHc+o75XcAfAd+YWIVjZrjPU1X7geuBTwGfBO4D/pfBtMSvJ9kL/ADw9ALdpxj8OnhjVV0EfJ3Br4CrWp8xd3OXWxn8YDsbOC3JLy9P5ePVXX2fVO8NHnXM3W9pNwM3VNVjYy9sgpY65iQvA364qv5mYkVNgOG+gKq6qapeXlWXAV8BPldVD1fVa6rq5Qye1I8u0PUgcLCq9nTrtzII+1Wvx5hfDXy+quaq6pvAx4AfX77Ke3vy2V/Du/ujXfvxfI3GIb5zSuJE+aqNPmN+1k7gkap636SKHLM+Y74UmEnyBeDTwIuS3DXRasfAcF9Akh/s7s+lexFlqO05wO8Dfza/X1UdAZ5I8uKuaQvw0LIU3dOoY2YwHbO5exdFGIx5//JUPRa3A9u65W3AbUPtV3bvptjMYLrp8HDHbv1rSTZ3Y79yqP9qNvKYAZK8G3gh8PZlqHVc+jzON1bV2VW1CfgJBhc+r1yesntY6Vd0V+ONwTsAHgL+E9jStV0NfK67Xce3P917NvCJob4vA2aB+4G/Bdat9HiWYcx/ADzMYO7+L4DvX+nxLDLGmxm8LvBNBr9lXQWcwWDO+BEG7xI6vds3DP7ozKPAA8DM0HHuG1qe6cb9KPAnz/4brZbbuMfM4Mq2GPwAv6+7vXWlxznpx3mobRMnyLtl/PoBSWqQ0zKS1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXo/wGToJISxEUu0AAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.hist([len(model.id_to_bins[id]) for id in model.id_to_bins])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "unique-belarus",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bestenv",
   "language": "python",
   "name": "bestenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
